# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
from torch import nn
from torch.autograd import Variable
from .backbone import build_backbone
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
from .hatransformer import build_hatransformer # TransformerEncoder, TransformerEncoderLayer

import numpy as np

import IPython
e = IPython.embed


def reparametrize(mu, logvar):
    """Sample from N(mu, exp(logvar)) using the reparameterization trick.

    Args:
        mu: tensor of means.
        logvar: tensor of log-variances, broadcastable against ``mu``.

    Returns:
        ``mu + std * eps`` with ``std = exp(logvar / 2)`` and ``eps`` drawn
        from a standard normal with the shape, dtype and device of ``std``.
        Gradients flow through ``mu`` and ``logvar`` but not through ``eps``.
    """
    std = logvar.div(2).exp()
    # randn_like replaces the deprecated Variable(std.data.new(...).normal_())
    # idiom; the noise still carries no gradient.
    eps = torch.randn_like(std)
    return mu + std * eps


def get_sinusoid_encoding_table(n_position, d_hid):
    """Build a fixed sinusoidal position-encoding table.

    Entry ``(pos, j)`` is ``sin(pos / 10000 ** (2 * (j // 2) / d_hid))`` for
    even ``j`` and the cosine of the same angle for odd ``j``.

    Args:
        n_position: number of positions (rows) in the table.
        d_hid: size of each encoding vector (columns).

    Returns:
        A float32 tensor of shape ``(1, n_position, d_hid)``. An empty table
        is returned when ``n_position`` is 0.
    """
    positions = np.arange(n_position, dtype=np.float64)[:, None]
    # Columns 2i and 2i+1 share one frequency, hence the integer division.
    exponents = 2 * (np.arange(d_hid) // 2) / d_hid
    table = positions / np.power(10000, exponents)[None, :]

    table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
    table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1

    return torch.from_numpy(table).float().unsqueeze(0)


class DETRVAE(nn.Module):
    """CVAE-style action predictor assembled from DETR components.

    When actions are supplied, an encoder summarises ``[CLS], qpos, actions``
    into a latent sample; a transformer then decodes ``num_queries`` action
    predictions from camera features, proprioception and that latent.
    Without actions the latent sample is all zeros.
    """
    def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
        """ Initializes the model.
        Parameters:
            backbones: list of backbone modules, indexed by camera position in forward
                       (see backbone.py), or None to run from robot/env state only.
            transformer: torch module of the transformer architecture. See transformer.py.
                         Must expose ``d_model``.
            encoder: module called as ``encoder(src, pos=..., src_key_padding_mask=...)``
                     to produce the latent statistics during training.
            state_dim: output size of the action head.
            num_queries: number of query embeddings; also sizes the encoder position table
                         (``1 + 1 + num_queries`` rows for [CLS], qpos and the action sequence).
            camera_names: sequence of camera names; its order indexes ``backbones`` and
                          the camera axis of ``image``.
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, state_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32 # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
        self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
        self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent mean, log-variance
        self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq

        # decoder extra parameters
        self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent

    def forward(self, qpos, image, env_state, actions=None, is_pad=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim (None at inference)
        is_pad: batch, seq boolean padding mask for ``actions``; None means
                that no action step is padding

        Returns ``(a_hat, is_pad_hat, [mu, logvar])``; ``mu`` and ``logvar``
        are None when ``actions`` is None.
        """
        is_training = actions is not None # train or val
        bs, _ = qpos.shape

        #### CVAE ENCODER PART ####

        ### Obtain latent z from action sequence
        if is_training:
            # project action sequence to embedding dim, and concat with a CLS token
            action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
            qpos_embed = self.encoder_joint_proj(qpos).unsqueeze(1)  # (bs, 1, hidden_dim)
            cls_embed = self.cls_embed.weight.unsqueeze(0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
            encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], dim=1) # (bs, seq+2, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2) # (seq+2, bs, hidden_dim)
            if is_pad is None:
                # No mask supplied: treat every action step as real data
                # instead of failing inside torch.cat below.
                is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=qpos.device)
            # do not mask the cls and qpos tokens
            cls_joint_is_pad = torch.zeros((bs, 2), dtype=torch.bool, device=qpos.device) # False: not a padding
            is_pad = torch.cat([cls_joint_is_pad, is_pad], dim=1)  # (bs, seq+2)
            # obtain position embedding
            pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2)  # (seq+2, 1, hidden_dim)
            # query model
            encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad) # ONLY TRANSFORMER ENCODER
            encoder_output = encoder_output[0] # take cls output only
            latent_info = self.latent_proj(encoder_output)
            # first half of the projection is the mean, second half the log-variance
            mu = latent_info[:, :self.latent_dim]
            logvar = latent_info[:, self.latent_dim:]
            latent_sample = reparametrize(mu, logvar)
            latent_input = self.latent_out_proj(latent_sample)
        else:
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32, device=qpos.device)
            latent_input = self.latent_out_proj(latent_sample)

        #### CVAE DECODER PART

        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, _ in enumerate(self.camera_names):
                features, pos = self.backbones[cam_id](image[:, cam_id])
                # only element 0 of each backbone output is used
                all_cam_features.append(self.input_proj(features[0]))
                all_cam_pos.append(pos[0])
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, dim=3)
            pos = torch.cat(all_cam_pos, dim=3)
            hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)
            hs = hs[-1]
        else:
            # NOTE: the latent computed above is not passed to the transformer in this branch.
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], dim=1) # seq length = 2
            hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)
            hs = hs[-1]
        a_hat = self.action_head(hs)
        is_pad_hat = self.is_pad_head(hs)
        return a_hat, is_pad_hat, [mu, logvar]

class JointEmbedder(nn.Module):
    """Embed each joint's features with one linear layer shared by all joints."""

    def __init__(self, input_dim, output_dim, num_joints):
        """
        input_dim: feature size of a single joint
        output_dim: embedding size produced for each joint
        num_joints: number of leading joints (along dim 1) that get embedded
        """
        super().__init__()
        self.num_joints = num_joints
        self.joint_layer = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """
        x: batch_size, num_joints, input_dim

        Returns a tensor of shape (batch_size, num_joints, output_dim). Joints
        beyond the first ``num_joints`` are ignored; fewer joints than
        ``num_joints`` raises IndexError.
        """
        if x.shape[1] < self.num_joints:
            raise IndexError(
                "expected at least %d joints along dim 1, got %d"
                % (self.num_joints, x.shape[1])
            )
        # nn.Linear acts on the last dimension, so the shared layer can embed
        # every joint in a single call instead of one call per joint.
        return self.joint_layer(x[:, :self.num_joints, :])
     
class HAT_DETRVAE(nn.Module):
    """Two-arm action predictor built around the HAT transformer.

    Each scalar of ``qpos`` is embedded as its own proprioception token, the
    transformer returns separate hidden states for the left and right arm,
    and one shared 7-dim action head decodes both. The CVAE encoder path is
    currently disabled: no latent is computed, and ``actions`` / ``is_pad``
    are accepted but not used by ``forward``.
    """

    def __init__(self, backbones, transformer, encoder, additional_input_length, state_dim, num_queries, camera_names):
        """ Initializes the model.
        Parameters:
            backbones: list of backbone modules, indexed by camera position in forward,
                       or None to run from robot/env state only.
            transformer: HAT transformer module; must expose ``d_model``.
            encoder: stored as ``cvae_encoder``; not called by ``forward``.
            additional_input_length: number of rows of the learned additional
                                     position embedding.
            state_dim: currently unused; the action head output size is fixed at 7.
            num_queries: number of query embeddings.
            camera_names: sequence of camera names; its order indexes ``backbones`` and
                          the camera axis of ``image``.
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.cvae_encoder = encoder
        self.hidden_dim = transformer.d_model
        hidden_dim = self.hidden_dim

        # one head shared by the left and right arm outputs
        self.action_head = nn.Linear(hidden_dim, 7)

        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(7, hidden_dim) # TODO
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(14, hidden_dim) # TODO
            self.input_proj_env_state = nn.Linear(7, hidden_dim) # TODO
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters ## TODO
        self.latent_dim = 32 # final size of latent z # TODO tune
        self.cls_input = nn.Embedding(1, hidden_dim)

        # decoder extra parameters
        self.additional_pos_embed = nn.Embedding(additional_input_length, hidden_dim) # learned position embedding for proprio and latent

        # Shared per-joint embedder for both arms. It used to be rebuilt with
        # fresh random weights inside every forward() call, so it was never
        # trained or saved. Created last so the initialisation order of the
        # layers above is unchanged.
        # NOTE: this adds ``joint_embedder.*`` entries to the state dict;
        # checkpoints saved before this change need ``strict=False`` to load.
        self.joint_embedder = JointEmbedder(1, hidden_dim, 7)

    @staticmethod
    def expand_qpos_to_tokens(qpos, num_tokens=14, hidden_dim=512):
        """Project ``qpos`` (batch_size, feature_count) to (batch_size, num_tokens, hidden_dim).

        NOTE(review): a brand-new, randomly initialised ``nn.Linear`` is built on
        every call, so the projection is neither trained nor repeatable. This
        helper is not used by ``forward``.
        """
        batch_size, feature_count = qpos.shape
        # One linear map producing all tokens at once, placed on qpos's device.
        expanded_projection = nn.Linear(feature_count, hidden_dim * num_tokens).to(qpos.device)
        expanded_output = expanded_projection(qpos)  # Shape: [batch_size, hidden_dim * num_tokens]
        # Reshape to [batch_size, num_tokens, hidden_dim]
        return expanded_output.view(batch_size, num_tokens, hidden_dim)

    @staticmethod
    def _summarize_attention(attn):
        """Sum attention mass over three fixed index ranges of the last axis.

        Returns ``[sum(0:7), sum(14:21), sum(21:26), total of the three]``, or
        None when no attention weights are available.
        """
        if attn is None:
            return None
        part_a = attn[:, :, 0:7].sum()
        part_b = attn[:, :, 14:21].sum()
        part_c = attn[:, :, 21:26].sum()
        return [part_a, part_b, part_c, part_a + part_b + part_c]

    def forward(self, qpos, image, env_state, actions=None, is_pad=None):
        """
        qpos: batch, qpos_dim (columns 0-6 are read as the left arm, 7 onward as the right arm)
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim (unused while the CVAE encoder is disabled)
        is_pad: unused while the CVAE encoder is disabled

        Returns ``(a_hat, is_pad_hat, cls_1, cls_2)``. ``is_pad_hat`` is always
        None. ``cls_1`` / ``cls_2`` are attention summaries (see
        ``_summarize_attention``) of the two attention tensors returned by the
        transformer, or None when running without backbones.
        """
        bs, _ = qpos.shape

        # The state-only branch returns no attention weights.
        attn_weights = attn_weights2 = None

        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, _ in enumerate(self.camera_names):
                features, pos = self.backbones[cam_id](image[:, cam_id])
                # only element 0 of each backbone output is used
                all_cam_features.append(self.input_proj(features[0]))
                all_cam_pos.append(pos[0])

            # proprioception features: one token per joint scalar
            qpos_left = qpos[:, :7].unsqueeze(2)
            qpos_right = qpos[:, 7:].unsqueeze(2)
            embedded_joints_left = self.joint_embedder(qpos_left)
            embedded_joints_right = self.joint_embedder(qpos_right)
            # Concatenate left and right proprioceptive embeddings
            proprio_input = torch.cat([embedded_joints_left, embedded_joints_right], dim=1)

            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, dim=3)
            pos = torch.cat(all_cam_pos, dim=3)

            # NOTE(review): the repeat count 25 is hard-coded; presumably tied to the
            # transformer's expected input layout. TODO confirm.
            cls_input = self.cls_input.weight.unsqueeze(0).repeat(25, bs, 1)

            hs_left, hs_right, attn_weights, attn_weights2 = self.transformer(src, None, self.query_embed.weight, pos, cls_input, proprio_input, len(self.camera_names), self.additional_pos_embed.weight)

        else:
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], dim=1) # seq length = 2
            hs_left, hs_right = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)

        a_hat_left = self.action_head(hs_left)
        a_hat_right = self.action_head(hs_right)

        # last axis: 7 left-arm values followed by 7 right-arm values
        a_hat = torch.cat([a_hat_left, a_hat_right], dim=2)
        a_hat = a_hat.permute(1, 0, 2) # swap the first two axes
        is_pad_hat = None

        # The third entry of cls_1 used to be read from attn_weights2 (copy-paste
        # slip); each summary now uses a single attention tensor.
        cls_1 = self._summarize_attention(attn_weights)
        cls_2 = self._summarize_attention(attn_weights2)

        return a_hat, is_pad_hat, cls_1, cls_2

class CNNMLP(nn.Module):
    """Baseline that feeds down-projected camera features plus qpos to an MLP."""

    def __init__(self, backbones, state_dim, camera_names):
        """ Initializes the model.
        Parameters:
            backbones: list of backbone modules, one per camera. See backbone.py.
                       None is not supported and raises NotImplementedError.
            state_dim: output size of ``action_head`` (that head is not used by
                       ``forward``; the MLP output size is fixed at 14).
            camera_names: sequence of camera names; its order indexes ``backbones`` and
                          the camera axis of ``image``.
        """
        super().__init__()
        self.camera_names = camera_names
        self.action_head = nn.Linear(1000, state_dim) # TODO add more
        if backbones is None:
            raise NotImplementedError

        self.backbones = nn.ModuleList(backbones)
        # one three-conv stack per backbone, shrinking channels to 32
        self.backbone_down_projs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
                nn.Conv2d(128, 64, kernel_size=5),
                nn.Conv2d(64, 32, kernel_size=5),
            )
            for backbone in backbones
        ])

        # 768 flattened features per camera (see forward) + 14 qpos values
        mlp_in_dim = 768 * len(backbones) + 14
        self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)

    def forward(self, qpos, image, env_state, actions=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None (unused)
        actions: batch, seq, action_dim (unused)

        Returns the MLP output computed from the flattened camera features
        concatenated with ``qpos``.
        """
        bs, _ = qpos.shape
        flattened_features = []
        for cam_id, _ in enumerate(self.camera_names):
            # the backbone's position output is not needed here
            features, _pos = self.backbones[cam_id](image[:, cam_id])
            # only element 0 of the backbone features is used
            cam_feature = self.backbone_down_projs[cam_id](features[0])
            flattened_features.append(cam_feature.reshape([bs, -1]))
        flattened_features = torch.cat(flattened_features, dim=1) # 768 each
        features = torch.cat([flattened_features, qpos], dim=1) # qpos: 14
        return self.mlp(features)


def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
    """Return an ``nn.Sequential`` of Linear layers separated by in-place ReLUs.

    ``hidden_depth == 0`` gives a single ``Linear(input_dim, output_dim)``.
    Otherwise the stack is one input layer, ``hidden_depth - 1`` further
    hidden layers of width ``hidden_dim``, and a final output projection.
    """
    if hidden_depth == 0:
        return nn.Sequential(nn.Linear(input_dim, output_dim))

    layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
    for _ in range(hidden_depth - 1):
        layers.append(nn.Linear(hidden_dim, hidden_dim))
        layers.append(nn.ReLU(inplace=True))
    layers.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*layers)


def build_encoder(args):
    """Construct the TransformerEncoder used as the CVAE encoder.

    Reads ``hidden_dim``, ``nheads``, ``dim_feedforward``, ``dropout``,
    ``enc_layers`` and ``pre_norm`` from ``args``; the activation is always
    "relu". A final LayerNorm is attached only when ``pre_norm`` is truthy.
    """
    width = args.hidden_dim  # 256
    pre_norm = args.pre_norm  # False

    layer = TransformerEncoderLayer(
        width,
        args.nheads,  # 8
        args.dim_feedforward,  # 2048
        args.dropout,  # 0.1
        "relu",
        pre_norm,
    )

    final_norm = None
    if pre_norm:
        final_norm = nn.LayerNorm(width)

    # TODO layer count is shared with the VAE decoder
    return TransformerEncoder(layer, args.enc_layers, final_norm)


def build_act(args):
    """Build a DETRVAE model from ``args`` and print its trainable parameter count.

    One backbone is built per entry in ``args.camera_names``.
    """
    # From image: a separate backbone for every camera
    backbones = [build_backbone(args) for _ in args.camera_names]
    transformer = build_transformer(args)
    encoder = build_encoder(args)

    model = DETRVAE(
        backbones,
        transformer,
        encoder,
        state_dim=14,  # TODO hardcode
        num_queries=args.num_queries,
        camera_names=args.camera_names,
    )

    # count only parameters that will receive gradients
    n_parameters = 0
    for param in model.parameters():
        if param.requires_grad:
            n_parameters += param.numel()
    print("number of parameters: %.2fM" % (n_parameters/1e6,))

    return model

def build_hatact(args):
    """Build a HAT_DETRVAE model from ``args`` and print its trainable parameter count.

    One backbone is built per entry in ``args.camera_names``.
    """
    # From image: a separate backbone for every camera
    backbones = [build_backbone(args) for _ in args.camera_names]
    transformer = build_hatransformer(args)
    encoder = build_encoder(args)

    model = HAT_DETRVAE(
        backbones,
        transformer,
        encoder,
        # rows of the learned additional position embedding (fixed; no latent variants)
        additional_input_length=26,
        state_dim=7,  # TODO hardcode
        num_queries=args.num_queries,
        camera_names=args.camera_names,
    )

    # count only parameters that will receive gradients
    n_parameters = 0
    for param in model.parameters():
        if param.requires_grad:
            n_parameters += param.numel()
    print("number of parameters: %.2fM" % (n_parameters/1e6,))

    return model

def build_cnnmlp(args):
    """Build a CNNMLP model from ``args`` and print its trainable parameter count.

    One backbone is built per entry in ``args.camera_names``.
    """
    # From image: a separate backbone for every camera
    backbones = [build_backbone(args) for _ in args.camera_names]

    model = CNNMLP(
        backbones,
        state_dim=14,  # TODO hardcode
        camera_names=args.camera_names,
    )

    # count only parameters that will receive gradients
    n_parameters = 0
    for param in model.parameters():
        if param.requires_grad:
            n_parameters += param.numel()
    print("number of parameters: %.2fM" % (n_parameters/1e6,))

    return model

